# implementation of Trust Region Newton-CG algorithm

import os
import time
import datetime
import torch
torch.set_default_dtype(torch.float64)
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import pickle
import numpy as np
import math
import torch.optim as optim
from torch.nn.parameter import Parameter
from collections import OrderedDict
import sys

class SparseTRCG:
    """Trust-region Newton-CG optimizer for a torch model.

    Each call to :meth:`step` approximately minimizes the local quadratic
    model of the loss inside a ball of radius ``self.radius`` using a
    truncated conjugate-gradient (CG) solve (:meth:`CGSolver`), applies the
    resulting direction to the model parameters, and then accepts or rejects
    it (and resizes the radius) based on the ratio of actual to predicted
    loss reduction.

    Parameters
    ----------
    data :
        Passed unchanged as the first argument to ``model.LossGrad``.
    model :
        Object exposing ``parameters()`` and
        ``LossGrad(data, sample=..., eval_BS=...)``; the first element of
        the value returned by ``LossGrad`` is used as the new loss.
    device :
        Stored on the instance; not used by the methods of this class.
    radius :
        Trust-region radius used by the first step.
    precondition :
        Diagonal preconditioner mode. ``0`` disables preconditioning;
        ``1``-``9`` select the running statistics built in :meth:`step`.
    eval_BS :
        Forwarded to ``model.LossGrad`` as ``eval_BS``.
    cgopttol :
        Initial CG tolerance; overwritten at the start of every CG solve.
    c0tr, c1tr, c2tr :
        Ratio thresholds: reject the step when ``rho <= c0tr``, shrink the
        radius when ``rho < c1tr``, allow enlargement when ``rho > c2tr``.
    t1tr, t2tr :
        Radius shrink and enlarge factors.
    radius_max :
        Upper bound applied when the radius is enlarged.
    radius_initial :
        Stored on the instance; not used by the methods of this class.
    """

    # precondition mode -> decay stored in DiagScale *after* the running
    # average of squared gradients has been refreshed
    _SQUARED_EMA_DECAY = {1: 0.95, 4: 0.99, 5: 0.90}
    # precondition mode -> decay stored in DiagScale *before* the running
    # average of absolute gradients is refreshed
    _ADAM_TYPE_DECAY = {7: 0.99, 8: 0.95, 9: 0.90}

    def __init__(self,data,model,device,radius,precondition,eval_BS,\
                 cgopttol=1e-7,c0tr=0.2,c1tr=0.25,c2tr=0.75,t1tr=0.25,t2tr=2.0,radius_max=5.0,\
                 radius_initial=0.1):
        self.data = data
        self.model = model
        self.device = device
        self.cgopttol = cgopttol
        self.c0tr = c0tr
        self.c1tr = c1tr
        self.c2tr = c2tr
        self.t1tr = t1tr
        self.t2tr = t2tr
        self.radius_max = radius_max
        self.radius_initial = radius_initial
        self.radius = radius
        self.cgmaxiter = 60  # hard cap on CG iterations per solve
        self.iterationCounterForAdamTypePreconditioning = 0
        self.precondition = precondition
        if self.precondition != 0:
            # one zero tensor per parameter, filled in by step()
            self.DiagPrecond = [w.data*0.0 for w in self.model.parameters()]
            self.DiagScale = 0.0
        self.eval_BS = eval_BS
        self.grad_bulk = []
        self.BS = 0
        self.newLOSS = 0.0

    def findroot(self,x,p):
        """Return ``alpha`` such that ``||x + alpha*p|| == self.radius``.

        ``x`` and ``p`` are equally long lists of tensors. The root of
        ``aa*alpha**2 + bb*alpha + cc = 0`` is evaluated in the form
        ``-2*cc / (bb + sqrt(bb**2 - 4*aa*cc))`` and returned as a Python
        float.
        """
        aa = 0.0
        bb = 0.0
        cc = 0.0
        for xe, pe in zip(x, p):
            aa += (pe*pe).sum()
            bb += (pe*xe).sum()
            cc += (xe*xe).sum()

        bb = bb*2.0
        cc = cc - self.radius**2

        alpha = (-2.0*cc)/(bb+(bb**2-(4.0*aa*cc)).sqrt())
        return alpha.item()

    def computeListNorm(self,lst):
        """Return the Euclidean norm of a list of tensors viewed as one vector."""
        return np.sum([(ri.data*ri.data).sum().item() for ri in lst])**0.5

    def computeListNormSq(self,lst):
        """Return the squared Euclidean norm of a list of tensors."""
        return np.sum([(ri.data*ri.data).sum().item() for ri in lst])

    def _refresh_preconditioner(self):
        """Rebuild ``self.SquaredPreconditioner`` from ``self.DiagPrecond``."""
        mode = self.precondition
        if mode < 6:
            self.SquaredPreconditioner = [1.0/torch.sqrt(di) for di in self.DiagPrecond]
        elif mode == 6:
            self.SquaredPreconditioner = [1.0/di for di in self.DiagPrecond]
        elif mode in self._ADAM_TYPE_DECAY:
            # bias-correction style factor based on the number of step() calls so far
            scl = np.sqrt(1 - np.power(self.DiagScale, self.iterationCounterForAdamTypePreconditioning))
            self.SquaredPreconditioner = [1.0/torch.sqrt(di)*scl for di in self.DiagPrecond]

    def _hessian_vector_product(self,loss_grad,direction,params):
        """Return the (preconditioned) Hessian applied to ``direction``.

        The directional derivative is reduced with the builtin ``sum`` so it
        stays a torch tensor attached to the autograd graph; handing graph
        tensors to ``np.sum`` would require NumPy to coerce them first.
        """
        if self.precondition == 0:
            directional = sum((gi*si).sum() for gi, si in zip(loss_grad, direction))
            return torch.autograd.grad(directional, params, retain_graph=True)

        scale = self.SquaredPreconditioner
        directional = sum((gi*(si*pr.data)).sum()
                          for gi, si, pr in zip(loss_grad, direction, scale))
        HP = torch.autograd.grad(directional, params, retain_graph=True)
        return [g*pr.data for g, pr in zip(HP, scale)]

    def CGSolver(self,loss_grad,cgloop):
        """Approximately solve the trust-region subproblem with truncated CG.

        Parameters
        ----------
        loss_grad :
            Sequence of gradient tensors (one per model parameter) that still
            carry an autograd graph, so Hessian-vector products can be taken.
        cgloop :
            List to which the wall-clock time of every Hessian-vector product
            is appended.

        Returns
        -------
        tuple
            ``(d, cg_iter, cg_term, cgloop, norm_p1)`` where ``d`` is the
            step (list of tensors), ``cg_iter`` the number of CG iterations,
            ``norm_p1`` the norm of the solution before the final
            preconditioner scaling, and ``cg_term`` the termination code:
            ``0`` iteration limit reached, ``1`` nonpositive curvature,
            ``2`` trust-region boundary hit, ``3`` residual below tolerance.
        """
        params = list(self.model.parameters())
        preconditioned = self.precondition != 0
        if preconditioned:
            self._refresh_preconditioner()

        # start from x = 0 with residual = (preconditioned) gradient
        x0 = [w.data*0 for w in params]
        if preconditioned:
            scale = self.SquaredPreconditioner
            r0 = [(g.data+0.0)*pr.data for g, pr in zip(loss_grad, scale)]
            p0 = [-(g.data+0.0)*pr.data for g, pr in zip(loss_grad, scale)]
        else:
            r0 = [g.data+0.0 for g in loss_grad]
            p0 = [-g.data+0.0 for g in loss_grad]

        # tolerance min(0.5, sqrt(||r0||)) * ||r0||, recomputed for every solve
        r0_norm = self.computeListNormSq(r0)**0.5
        self.cgopttol = (min(0.5, r0_norm**0.5))*r0_norm

        cg_term = 0
        j = 0
        while True:
            j += 1
            self.CG_STEPS_TOOK = j
            if j > self.cgmaxiter:
                # iteration budget exhausted: return the last iterate
                j = j-1
                p1 = x0
                break

            cgloopstart = time.time()
            HP = self._hessian_vector_product(loss_grad, p0, params)
            PHP = np.sum([(Hpi*p0i).sum().item() for Hpi, p0i in zip(HP, p0)])
            cgloop.append(time.time() - cgloopstart)

            if PHP <= 0:
                # nonpositive curvature: follow p0 to the boundary
                tau = self.findroot(x0, p0)
                p1 = [xi+tau*p0i for xi, p0i in zip(x0, p0)]
                cg_term = 1
                break

            rr0 = self.computeListNormSq(r0)
            # plain float so scaling tensors does not depend on NumPy-scalar dispatch
            alpha = float(rr0/PHP)

            x1 = [xi+alpha*pi for xi, pi in zip(x0, p0)]
            if self.computeListNorm(x1) >= self.radius:
                # full CG step leaves the region: stop on the boundary instead
                tau = self.findroot(x0, p0)
                p1 = [xi+tau*pi for xi, pi in zip(x0, p0)]
                cg_term = 2
                break

            r1 = [ri+alpha*HPi for ri, HPi in zip(r0, HP)]
            rr1 = self.computeListNormSq(r1)
            if rr1**0.5 < self.cgopttol:
                p1 = x1
                cg_term = 3
                break

            beta = float(rr1/rr0)
            p0 = [-ri+beta*pi for ri, pi in zip(r1, p0)]
            x0 = x1
            r0 = r1

        cg_iter = j
        # norm is taken before mapping back through the preconditioner
        norm_p1 = self.computeListNorm(p1)
        if preconditioned:
            p1 = [pi*pr.data for pi, pr in zip(p1, self.SquaredPreconditioner)]
        d = p1

        return d,cg_iter,cg_term,cgloop,norm_p1

    def _ema_accumulate(self,loss_grad,use_abs):
        """Blend ``|g|`` (``use_abs``) or ``g*g`` into ``DiagPrecond`` with ``DiagScale``."""
        decay = self.DiagScale
        for gi, di in zip(loss_grad, self.DiagPrecond):
            if use_abs:
                di.data.copy_(di.data*decay+(1-decay)*torch.abs(gi.data))
            else:
                di.data.copy_(di.data*decay+(1-decay)*gi.data*gi.data)
            # exact zeros are bumped to 1 so the later reciprocal stays finite
            di.data[di.data==0]+=1.0

    def _update_diag_preconditioner(self,loss_grad):
        """Refresh ``DiagPrecond`` / ``DiagScale`` for the active mode.

        Each mode is applied exactly once per call. For modes 1, 3, 4, 5 and
        6 the new ``DiagScale`` is stored after the update, so the very first
        call still uses the value set in ``__init__``.
        """
        mode = self.precondition
        if mode in self._SQUARED_EMA_DECAY:
            self._ema_accumulate(loss_grad, use_abs=False)
            self.DiagScale = self._SQUARED_EMA_DECAY[mode]
        elif mode == 2:
            self.DiagScale = 0.001
            self.exponent = 0.75
            for gi, di in zip(loss_grad, self.DiagPrecond):
                di.data.copy_((gi.data*gi.data + self.DiagScale)**self.exponent)
        elif mode == 3:
            for gi, di in zip(loss_grad, self.DiagPrecond):
                di.data.copy_(1.0-self.DiagScale+self.DiagScale*gi.data*gi.data)
            self.DiagScale = 1e-2
        elif mode == 6:
            self._ema_accumulate(loss_grad, use_abs=True)
            self.DiagScale = 0.95
        elif mode in self._ADAM_TYPE_DECAY:
            self.DiagScale = self._ADAM_TYPE_DECAY[mode]
            self.iterationCounterForAdamTypePreconditioning += 1
            self._ema_accumulate(loss_grad, use_abs=True)

    def step(self,oldloss,loss_grad,sample=None):
        """Perform one trust-region step and update the model in place.

        Parameters
        ----------
        oldloss :
            Loss value at the current parameters.
        loss_grad :
            Re-iterable sequence of gradient tensors, one per model
            parameter, each with a ``grad_fn``.
        sample :
            Forwarded to ``model.LossGrad`` when the trial point is evaluated.

        Returns
        -------
        tuple
            ``(d, rho, update, CGITER, cg_term, loss_grad, norm_d, norm_p1,
            numerator, denominator, firstloop, secondloop, thirdloop, cgloop,
            f_cost, g_cost)``. ``rho`` is the actual/predicted reduction
            ratio. ``update`` is ``0`` when the radius was shrunk, ``1`` when
            it was enlarged and ``3`` otherwise or when the step was
            rejected. The ``*loop`` lists hold wall-clock timings;
            ``f_cost`` / ``g_cost`` count loss evaluations and
            gradient/Hessian-vector evaluations.

        Raises
        ------
        ValueError
            If any gradient tensor has no ``grad_fn``.
        """
        f_cost=0 # fun eval
        g_cost=0 # grad and/or Hv

        for gi in loss_grad:
            if gi.grad_fn is None:
                raise ValueError('no grad_fn found in %s'%repr(gi))

        CGITER = 0.0
        params = list(self.model.parameters())

        # snapshot so a rejected step can be rolled back
        w0 = [a.data+0.0 for a in params]

        firstloopstart = time.time()
        firstloop = [time.time() - firstloopstart]

        self._update_diag_preconditioner(loss_grad)

        secondloop = []
        thirdloop = []
        cgloop = []

        # Conjugate Gradient Method
        d, cg_iter, cg_term, cgloop, norm_p1 = self.CGSolver(loss_grad,cgloop)

        CGITER += cg_iter
        g_cost += cg_iter

        # curvature d^T H d and slope g^T d of the quadratic model along d
        secondloopstart = time.time()
        directional = sum((gi*di).sum() for gi, di in zip(loss_grad, d))
        Hd = torch.autograd.grad(directional, params, retain_graph=True)
        DHD = np.sum([(Hdi*di).sum().item() for Hdi, di in zip(Hd, d)])
        GD = np.sum([(gi.data*di).sum().item() for gi, di in zip(loss_grad, d)])
        g_cost += 1.0
        secondloop.append(time.time()-secondloopstart)

        # move to the trial point
        with torch.no_grad():
            for wi, di in zip(params, d):
                wi.add_(di+0.0)

        thirdloopstart = time.time()
        self.newLOSS,_ = self.model.LossGrad(self.data,sample=sample,eval_BS=self.eval_BS)
        f_cost += 1.0
        thirdloop.append(time.time()-thirdloopstart)
        norm_d = self.computeListNorm(d)

        numerator = oldloss - self.newLOSS   # actual reduction
        denominator = -GD - 0.5*DHD          # reduction predicted by the model
        rho = numerator/denominator

        # NOTE(review): 3 is also reported when the step is accepted with an
        # unchanged radius; confirm callers do not read it as "rejected".
        update = 3
        if rho < self.c1tr: # shrink radius
            self.radius = self.t1tr*self.radius
            update = 0
        if rho > self.c2tr and np.abs(norm_p1 - self.radius) < 1e-10: # enlarge radius
            self.radius = min(self.t2tr*self.radius,self.radius_max)
            update = 1
        # otherwise, radius remains the same
        if rho <= self.c0tr or np.isnan(rho): # reject d
            # NOTE(review): a NaN ratio rejects the step without shrinking the radius.
            update = 3
            with torch.no_grad():
                for wi, w0i in zip(params, w0):
                    wi.set_(w0i+0.0)

        return d, rho, update, CGITER, cg_term, loss_grad, norm_d, norm_p1, numerator, denominator,\
               firstloop, secondloop, thirdloop, cgloop, f_cost, g_cost
    